#include "easy_grpc/easy_grpc.h"

#include "generated/test.egrpc.pb.h"

#include <fstream>
#include "certs.h"
#include "gtest/gtest.h"

namespace rpc = easy_grpc;

namespace {

// Trivial TestService used by every test in this file: it answers each
// TestMethod call with the request's name suffixed by "_replied".
class Test_sync_impl : public tests::TestService {
 public:
  // Produces the reply eagerly and returns a future whose value has already
  // been set, so callers never wait on any server-side work.
  ::rpc::Future<::tests::TestReply> TestMethod(::tests::TestRequest req, ::easy_grpc::Context) override {
    ::rpc::Promise<::tests::TestReply> promise;
    auto reply_future = promise.get_future();

    ::tests::TestReply reply;
    reply.set_name(req.name() + "_replied");
    promise.set_value(reply);

    return reply_future;
  }
};

}  // namespace

// Round-trips a single unary call to Test_sync_impl over an insecure
// loopback connection.
TEST(test_easy_grpc, simple_rpc) {
  rpc::Environment grpc_env;

  std::array<rpc::Completion_queue, 1> server_queues;
  rpc::Completion_queue client_queue;

  Test_sync_impl sync_srv;

  // Listening on port 0 asks for any free port; the port actually bound is
  // expected to be reported back through &server_port.
  int server_port = 0;
  // NOTE(review): the builder calls presumably return Config&, hence the
  // std::move to hand the finished config to Server as an rvalue.
  rpc::server::Server server = std::move(rpc::server::Config()
                                             .add_default_listening_queues({server_queues.begin(), server_queues.end()})
                                             .add_service(sync_srv)
                                             .add_listening_port("127.0.0.1:0", {}, &server_port));

  // Fatal: without a bound port the client below would dial port 0, which can
  // only fail and would obscure the real cause.
  ASSERT_NE(0, server_port);

  {
    rpc::client::Unsecure_channel channel(std::string("127.0.0.1:") + std::to_string(server_port), &client_queue);
    tests::TestService::Stub stub(&channel);

    ::tests::TestRequest req;
    req.set_name("dude");
    EXPECT_EQ(stub.TestMethod(req).get().name(), "dude_replied");
  }
}

// Unary call over TLS where the server is configured with a root cert and the
// client presents its own key and certificate chain.
TEST(test_easy_grpc, secure_rpc_full) {
  // The test certificates are issued for this host name, so it must resolve
  // to a local address on the machine running the test.
  const std::string server_url = "cogmentserver.com:";
  rpc::Environment grpc_env;

  int server_port = 0;
  std::array<rpc::Completion_queue, 1> server_queues;
  Test_sync_impl sync_srv;
  auto server_creds = std::make_shared<rpc::server::Credentials>(ROOT_CERT, SERVER_PRIVATE_KEY, SERVER_TRUST_CHAIN);
  rpc::server::Server server = std::move(rpc::server::Config()
                                             .add_default_listening_queues({server_queues.begin(), server_queues.end()})
                                             .add_service(sync_srv)
                                             .add_listening_port(server_url + "0", server_creds, &server_port));

  // Fatal: there is no point attempting a secure handshake on port 0.
  ASSERT_NE(0, server_port);

  {
    // Distinct name so the client credentials no longer shadow the server's.
    auto client_creds = rpc::client::Credentials(ROOT_CERT, CLIENT_PRIVATE_KEY, CLIENT_TRUST_CHAIN);
    rpc::Completion_queue client_queue;
    rpc::client::Secure_channel channel(server_url + std::to_string(server_port), &client_queue, &client_creds);
    tests::TestService::Stub stub(&channel);

    ::tests::TestRequest req;
    req.set_name("dude");
    EXPECT_EQ(stub.TestMethod(req).get().name(), "dude_replied");
  }
}

// TLS server that is given no root cert and is configured with
// GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE; verifies that clients both with and
// without their own key/chain can call it.
TEST(test_easy_grpc, secure_rpc_partial) {
  // The test certificates are issued for this host name, so it must resolve
  // to a local address on the machine running the test.
  const std::string server_url = "cogmentserver.com:";
  rpc::Environment grpc_env;

  int server_port = 0;
  std::array<rpc::Completion_queue, 1> server_queues;
  Test_sync_impl sync_srv;
  auto server_creds = std::make_shared<rpc::server::Credentials>(
      nullptr, SERVER_PRIVATE_KEY, SERVER_TRUST_CHAIN,
      grpc_ssl_client_certificate_request_type::GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
  rpc::server::Server server = std::move(rpc::server::Config()
                                             .add_default_listening_queues({server_queues.begin(), server_queues.end()})
                                             .add_service(sync_srv)
                                             .add_listening_port(server_url + "0", server_creds, &server_port));

  // Fatal: neither client below can connect if the server never bound a port.
  ASSERT_NE(0, server_port);

  // Client that still supplies its own key and certificate chain.
  {
    auto mutual_client_creds = rpc::client::Credentials(ROOT_CERT, CLIENT_PRIVATE_KEY, CLIENT_TRUST_CHAIN);
    rpc::Completion_queue client_queue;
    rpc::client::Secure_channel channel(server_url + std::to_string(server_port), &client_queue,
                                        &mutual_client_creds);
    tests::TestService::Stub stub(&channel);

    ::tests::TestRequest req;
    req.set_name("dude");
    EXPECT_EQ(stub.TestMethod(req).get().name(), "dude_replied");
  }

  // Client that only carries the root cert (no key, no chain).
  {
    auto root_only_client_creds = rpc::client::Credentials(ROOT_CERT, nullptr, nullptr);
    rpc::Completion_queue client_queue;
    rpc::client::Secure_channel channel(server_url + std::to_string(server_port), &client_queue,
                                        &root_only_client_creds);
    tests::TestService::Stub stub(&channel);

    ::tests::TestRequest req;
    req.set_name("yo");
    EXPECT_EQ(stub.TestMethod(req).get().name(), "yo_replied");
  }
}

// Issues many calls before waiting on any of them, distributing them
// round-robin over several client completion queues, then checks every reply.
TEST(test_easy_grpc, big_volume) {
  rpc::Environment grpc_env;

  // NOTE(review): these size the server/client completion-queue arrays; the
  // "threads" naming assumes one worker per queue — confirm in easy_grpc.
  constexpr int receiving_threads = 3;
  constexpr int sending_threads = 3;
  constexpr int rpcs_to_send = 10000;

  std::array<rpc::Completion_queue, receiving_threads> server_queues;
  std::array<rpc::Completion_queue, sending_threads> client_queues;

  Test_sync_impl sync_srv;

  int server_port = 0;
  rpc::server::Server server = std::move(rpc::server::Config()
                                             .add_default_listening_queues({server_queues.begin(), server_queues.end()})
                                             .add_service(sync_srv)
                                             .add_listening_port("127.0.0.1:0", {}, &server_port));

  // Fatal: otherwise a failed bind would still fire thousands of doomed calls
  // and flood the log with one failure per call.
  ASSERT_NE(0, server_port);

  // No default queue on the channel: every call below selects its queue
  // explicitly through Call_options.
  rpc::client::Unsecure_channel channel(std::string("127.0.0.1:") + std::to_string(server_port), nullptr);
  tests::TestService::Stub stub(&channel);

  ::tests::TestRequest req;
  req.set_name("dude");

  std::vector<rpc::Future<::tests::TestReply>> results;
  results.reserve(rpcs_to_send);

  for (int i = 0; i < rpcs_to_send; ++i) {
    rpc::client::Call_options options;
    options.completion_queue = &client_queues[i % sending_threads];
    results.emplace_back(stub.TestMethod(req, options));
  }

  for (auto& f : results) {
    EXPECT_EQ(f.get().name(), "dude_replied");
  }
}

// Sends a unary call carrying one custom metadata entry through
// Call_options::headers and checks the reply is unaffected.
TEST(test_easy_grpc, rpc_with_headers) {
  rpc::Environment grpc_env;

  std::array<rpc::Completion_queue, 1> server_queues;
  rpc::Completion_queue client_queue;

  Test_sync_impl sync_srv;

  int server_port = 0;
  rpc::server::Server server = std::move(rpc::server::Config()
                                             .add_default_listening_queues({server_queues.begin(), server_queues.end()})
                                             .add_service(sync_srv)
                                             .add_listening_port("127.0.0.1:0", {}, &server_port));

  // Fatal: the call below needs a real port to connect to.
  ASSERT_NE(0, server_port);

  {
    rpc::client::Unsecure_channel channel(std::string("127.0.0.1:") + std::to_string(server_port), &client_queue);
    tests::TestService::Stub stub(&channel);

    ::tests::TestRequest req;
    req.set_name("dude");

    rpc::client::Call_options options;
    std::vector<grpc_metadata> headers;

    // Value-initialize the C struct so any fields besides key/value start
    // zeroed rather than holding indeterminate stack contents.
    grpc_metadata trial_header{};
    trial_header.key = grpc_slice_from_static_string("trial_id");
    trial_header.value = grpc_slice_from_static_string("123567890");
    headers.push_back(trial_header);
    // options stores a pointer, so `headers` must outlive the call below.
    options.headers = &headers;

    EXPECT_EQ(stub.TestMethod(req, options).get().name(), "dude_replied");

    // Release the slices created above.
    grpc_slice_unref(trial_header.key);
    grpc_slice_unref(trial_header.value);
  }
}
